import torch.nn as nn

class TokenEmbedding(nn.Embedding):
    def __init__(self,vocab_size,embed_size=512):
        super().__init__(vocab_size,embed_size,padding_idx=0)